
feature/graphical pytrees#1103

Merged
Jammy2211 merged 22 commits into main from feature/graphical_pytrees
Apr 9, 2025
Conversation

@rhayes777
Collaborator

  • move jax serialise/deserialise test util to conftest
  • pytree methods for FactorGraphModel
  • pytree methods for AnalysisFactor
  • tree flatten for LogUniformPrior
  • prior ids as children when creating pytrees
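The pytree methods listed above follow JAX's standard registration pattern. A minimal sketch of that pattern is below; `ToyUniformPrior` and `width` are illustrative names for this example, not PyAutoFit's actual classes:

```python
import jax
from jax.tree_util import register_pytree_node_class

# Hypothetical sketch of registering a prior-like class as a pytree.
@register_pytree_node_class
class ToyUniformPrior:
    def __init__(self, lower_limit, upper_limit):
        self.lower_limit = lower_limit
        self.upper_limit = upper_limit

    def tree_flatten(self):
        # Dynamic values become children; static metadata would go in aux_data.
        children = (self.lower_limit, self.upper_limit)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilds the instance; under jit the children are tracers.
        return cls(*children)

# jit-compiled functions can now accept the prior directly as an argument.
@jax.jit
def width(prior):
    return prior.upper_limit - prior.lower_limit

print(width(ToyUniformPrior(0.0, 2.0)))
```

Because `tree_unflatten` calls `cls(*children)`, the constructor must accept traced values, which is relevant to the error discussed below.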

@Jammy2211
Collaborator

After implementing the fixes in this PR: #1125

I now get this error when running the concr_cosmology example: github.com/Jammy2211/concr_cosmology

2025-04-01 16:54:03,778 - autofit.non_linear.initializer - INFO - Generating initial samples of model using JAX LH Function cores
/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/messages/transform.py:174: RuntimeWarning: divide by zero encountered in reciprocal
  super().__init__(DiagonalMatrix(np.reciprocal(self.scale)))
Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 153, in _fit
    raise RuntimeError
RuntimeError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAuto/concr_cosmology/start_here.py", line 521, in <module>
    result = search.fit(model=factor_graph.global_prior_model, analysis=factor_graph)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 599, in fit
    result = self.start_resume_fit(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 120, in decorated
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 750, in start_resume_fit
    search_internal = self._fit(
                      ^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 184, in _fit
    search_internal = self.search_internal_from(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/static.py", line 153, in search_internal_from
    live_points = self.live_points_init_from(model=model, fitness=fitness)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 447, in live_points_init_from
    ) = self.initializer.samples_from_model(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/initializer.py", line 70, in samples_from_model
    return self.samples_jax(
           ^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/initializer.py", line 176, in samples_jax
    figure_of_merit = self.figure_of_metric((fitness, parameter_list))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/initializer.py", line 34, in figure_of_metric
    figure_of_merit = fitness(parameters=parameter_list)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/timeout_decorator/timeout_decorator.py", line 79, in new_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/fitness.py", line 158, in __call__
    log_likelihood = self.log_likelihood_function(instance=instance)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior/abstract.py", line 64, in tree_unflatten
    return cls(*children)
           ^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior/uniform.py", line 48, in __init__
    lower_limit = float(lower_limit)
                  ^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
The error occurred while tracing the function log_likelihood_function at /mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/graphical/declarative/collection.py:87 for jit.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
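The `ConcretizationTypeError` above can be reproduced in isolation. This standalone sketch, independent of PyAutoFit, shows both the failing `float()` call on a traced value and the `astype` fix the error message suggests:

```python
import jax
import jax.numpy as jnp

# Calling the Python builtin float() on a traced value inside jit fails,
# because tracing has no concrete value to convert.
@jax.jit
def bad(x):
    return float(x)  # raises ConcretizationTypeError during tracing

try:
    bad(jnp.array(1.0))
except jax.errors.ConcretizationTypeError:
    print("float() on a tracer fails under jit")

# The fix suggested by the error message keeps the value abstract.
@jax.jit
def good(x):
    return x.astype(float)

print(good(jnp.array(1.0)))
```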

I think the problem is that JAX is trying to trace and jit functions it should not need to (i.e. functions above the scope of the log_likelihood_function of each Analysis).

The giveaway is this line:

  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 750, in start_resume_fit
    search_internal = self._fit(

Basically, the _fit function of DynestyStatic is being traced and jitted by JAX.

The fit function is a level above the Analysis and likelihood function, and should not be something we try to trace and convert with JAX.

The other clue is this message:

The error occurred while tracing the function log_likelihood_function at /mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/graphical/declarative/collection.py:87 for jit.

Line 87 corresponds to this:

    def log_likelihood_function(self, instance: ModelInstance) -> float:
        """
        Compute the combined likelihood of each factor from a collection of instances
        with the same ordering as the factors.

        Parameters
        ----------
        instance
            A collection of instances, one corresponding to each factor

        Returns
        -------
        The combined likelihood of all factors
        """
        log_likelihood = 0
        for model_factor, instance_ in zip(self.model_factors, instance):
            log_likelihood += model_factor.log_likelihood_function(instance_)

        return log_likelihood

So basically the JAX tracer is being applied to something within this function that it shouldn't be, although I am unclear on what exactly...
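One plausible mechanism, consistent with the end of the traceback above (tree_unflatten → cls(*children) → float(lower_limit)), is a pytree-registered class whose __init__ coerces its arguments with `float()`. The sketch below uses a hypothetical `CoercingPrior`, not PyAutoFit's actual class, to reproduce the failure:

```python
import jax
from jax.tree_util import register_pytree_node_class

# Hypothetical pytree class whose constructor coerces to a concrete float.
@register_pytree_node_class
class CoercingPrior:
    def __init__(self, lower_limit):
        self.lower_limit = float(lower_limit)  # only works on concrete values

    def tree_flatten(self):
        return (self.lower_limit,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # jit rebuilds the object from traced children, re-running __init__
        # and hence float() on a tracer.
        return cls(*children)

@jax.jit
def evaluate(prior):
    return prior.lower_limit * 2.0

try:
    evaluate(CoercingPrior(1.0))
except jax.errors.ConcretizationTypeError:
    print("tree_unflatten re-ran float() on a tracer")
```

If this is the mechanism, any prior that crosses a jit boundary as a pytree would trigger the error, even when the likelihood function itself is JAX-safe.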

@rhayes777
Collaborator Author

Serialisation is occurring because jit is applied to FactorGraphModel.log_likelihood, resulting in self being serialised and deserialised.
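A minimal sketch of the distinction, using an illustrative `Model` class rather than PyAutoFit's FactorGraphModel: when jit is applied so that the instance becomes a traced argument, `self` is flattened and unflattened as a pytree on every call; closing over the instance instead keeps `self` outside the jit boundary entirely:

```python
import jax
import jax.numpy as jnp

# Illustrative stand-in for a model class; not PyAutoFit's implementation.
class Model:
    def __init__(self, scale):
        self.scale = scale  # plain Python float, never traced below

    def log_likelihood(self, x):
        return -0.5 * (x / self.scale) ** 2

model = Model(scale=2.0)

# The closure captures `model` as a constant, so only `x` is traced and
# `self` is never flattened/unflattened by jit.
log_likelihood_jit = jax.jit(lambda x: model.log_likelihood(x))
print(log_likelihood_jit(jnp.array(1.0)))
```

Under this pattern the instance's attributes are baked into the trace as constants, which avoids the serialise/deserialise round trip described above (at the cost of recompiling if the instance changes).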

@Jammy2211 Jammy2211 merged commit d2d3d08 into main Apr 9, 2025
0 of 4 checks passed
@Jammy2211 Jammy2211 deleted the feature/graphical_pytrees branch June 24, 2025 13:49